{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\l\\Anaconda3\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# diroot_path=r'/W9_object_detection/'stributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# ==============================================================================\n",
    "\n",
    "r\"\"\"Convert the Oxford pet dataset to TFRecord for object_detection.\n",
    "\n",
    "See: O. M. Parkhi, A. Vedaldi, A. Zisserman, C. V. Jawahar\n",
    "     Cats and Dogs\n",
    "     IEEE Conference on Computer Vision and Pattern Recognition, 2012\n",
    "     http://www.robots.ox.ac.uk/~vgg/data/pets/\n",
    "\n",
    "Example usage:\n",
    "    python object_detection/dataset_tools/create_pet_tf_record.py \\\n",
    "        --data_dir=/home/user/pet \\\n",
    "        --output_dir=/home/user/pet/output\n",
    "\"\"\"\n",
    "import sys,os\n",
    "root_path = \"C:/Users/l/Desktop/CSDN培训资料/第9轴_卷积神经网络分类任务和检测任务/W9_object_detection\" \n",
    "sys.path.append(root_path)\n",
    "\n",
    "import hashlib\n",
    "import io\n",
    "import logging\n",
    "import random\n",
    "import re\n",
    "\n",
    "from lxml import etree\n",
    "import numpy as np\n",
    "import PIL.Image\n",
    "import tensorflow as tf\n",
    "\n",
    "from object_detection.utils import dataset_util\n",
    "from object_detection.utils import label_map_util\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path1 = \"C:/Users/l/Desktop/CSDN培训资料/第9轴_卷积神经网络分类任务和检测任务/W9_object_detection/\"  \n",
    "data_path = os.path.join(path1,'object_detection/dataset_tools')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "metadata": {},
   "outputs": [],
   "source": [
    "path1 = r\"C:/Users/l/Desktop/CSDN培训资料/第9轴_卷积神经网络分类任务和检测任务/W9_object_detection/\"  \n",
    "data_path = os.path.join(path1,'object_detection/dataset_tools/')\n",
    "\n",
    "flags = tf.app.flags\n",
    "flags.DEFINE_string('new5_data_dir', os.path.join(data_path,'data/'), 'Root directory to raw pet dataset.')\n",
    "flags.DEFINE_string('new5_output_dir', os.path.join(data_path,'data/out/'), 'Path to directory to output TFRecords.')\n",
    "flags.DEFINE_string('new5_label_map_path', os.path.join('data/labels_items.txt'),'Path to label map proto')\n",
    "FLAGS = flags.FLAGS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 137,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C:/Users/l/Desktop/CSDN培训资料/第9轴_卷积神经网络分类任务和检测任务/W9_object_detection/object_detection/dataset_tools/data/\n",
      "{'computer': 1, 'monitor': 2, 'scuttlebutt': 3, 'water dispenser': 4, 'drawer chest': 5}\n",
      "C:/Users/l/Desktop/CSDN培训资料/第9轴_卷积神经网络分类任务和检测任务/W9_object_detection/object_detection/dataset_tools/data/images/\n",
      "C:/Users/l/Desktop/CSDN培训资料/第9轴_卷积神经网络分类任务和检测任务/W9_object_detection/object_detection/dataset_tools/data/annotations/\n",
      "C:/Users/l/Desktop/CSDN培训资料/第9轴_卷积神经网络分类任务和检测任务/W9_object_detection/object_detection/dataset_tools/data/annotations/xmls/\n"
     ]
    }
   ],
   "source": [
    "data_dir = FLAGS.new5_data_dir\n",
    "label_map_dict = label_map_util.get_label_map_dict(FLAGS.new5_label_map_path)\n",
    "\n",
    "logging.info('Reading from Pet dataset.')\n",
    "image_dir = os.path.join(data_dir, 'images/')\n",
    "annotations_dir = os.path.join(data_dir, 'annotations/')\n",
    "examples_path = os.path.join(annotations_dir, 'xmls/')\n",
    "\n",
    "print(data_dir)\n",
    "print(label_map_dict)\n",
    "print(image_dir)\n",
    "print(annotations_dir)\n",
    "print(examples_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 144,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['0000.xml', '0001.xml', '0002.xml', '0003.xml', '0004.xml']\n"
     ]
    }
   ],
   "source": [
    "#examples_list=[]\n",
    "#for file in os.listdir(examples_path):\n",
    "#    path = os.path.join(examples_path,str(file)) \n",
    "#   with tf.gfile.GFile(path) as fid:\n",
    "#       lines = fid.readlines()\n",
    "#    example_list = [line.strip().split(' ')[0] for line in lines]\n",
    "#    examples_list.append(example_list)\n",
    "\n",
    "# examples_list = dataset_util.read_examples_list(examples_path)\n",
    "\n",
    "examples_list = os.listdir(examples_path)\n",
    "\n",
    "print(examples_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "metadata": {},
   "outputs": [],
   "source": [
    "random.seed(42)\n",
    "random.shuffle(examples_list)\n",
    "num_examples = len(examples_list)\n",
    "num_train = int(0.7 * num_examples)\n",
    "train_examples = examples_list[:num_train]\n",
    "val_examples = examples_list[num_train:]\n",
    "logging.info('%d training and %d validation examples.',\n",
    "           len(train_examples), len(val_examples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 146,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C:/Users/l/Desktop/CSDN培训资料/第9轴_卷积神经网络分类任务和检测任务/W9_object_detection/object_detection/dataset_tools/data/out/pet_train.record\n"
     ]
    }
   ],
   "source": [
    "train_output_path = os.path.join(FLAGS.new5_output_dir, 'pet_train.record')\n",
    "val_output_path = os.path.join(FLAGS.new5_output_dir, 'pet_val.record')\n",
    "print(train_output_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 202,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ERROR! Session/line number was not unique in database. History logging moved to new session 757\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:/Users/l/Desktop/CSDN培训资料/第9轴_卷积神经网络分类任务和检测任务/W9_object_detection\\object_detection\\utils\\dataset_util.py:75: FutureWarning: The behavior of this method will change in future versions. Use specific 'len(elem)' or 'elem is not None' test instead.\n",
      "  if not xml:\n"
     ]
    }
   ],
   "source": [
    "# create_tf_record(train_output_path, label_map_dict, annotations_dir,image_dir, train_examples)\n",
    "# def create_tf_record(output_filename, label_map_dict,annotations_dir,image_dir,examples):\n",
    "\n",
    "#writer = tf.python_io.TFRecordWriter(train_output_path)\n",
    "for idx, example in enumerate(train_examples):\n",
    "    if idx % 100 == 0:\n",
    "        logging.info('On image %d of %d', idx, len(train_examples))\n",
    "    xml_path = os.path.join(annotations_dir,'xmls/',str(example))\n",
    "    \n",
    "    if not os.path.exists(xml_path):\n",
    "        logging.warning('Could not find %s, ignoring example.', xml_path)\n",
    "        continue\n",
    "\n",
    "    with tf.gfile.GFile(xml_path, 'r') as fid:\n",
    "        xml_str = fid.read()\n",
    "    xml = etree.fromstring(xml_str)\n",
    "    data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']\n",
    "        \n",
    "    #try:\n",
    "        #tf_example = dict_to_tf_example(data, label_map_dict, image_dir)\n",
    "        #writer.write(tf_example.SerializeToString())\n",
    "    #except ValueError:\n",
    "        #logging.warning('Invalid example: %s, ignoring.', xml_path)\n",
    "#writer.close()\n",
    "\n",
    "#tf_example = dict_to_tf_example(data, label_map_dict, image_dir)\n",
    "#def dict_to_tf_example(data,label_map_dict,image_subdirectory,ignore_difficult_instances=False):\n",
    "    ignore_difficult_instances=False\n",
    "    img_path = os.path.join(image_dir, data['filename'])  \n",
    "    with tf.gfile.GFile(img_path, 'rb') as fid:\n",
    "        encoded_jpg = fid.read()\n",
    "    encoded_jpg_io = io.BytesIO(encoded_jpg)\n",
    "    image = PIL.Image.open(encoded_jpg_io)\n",
    "   \n",
    "    if image.format != 'JPEG':\n",
    "        raise ValueError('Image format not JPEG')\n",
    "    key = hashlib.sha256(encoded_jpg).hexdigest()\n",
    "\n",
    "    width = int(data['size']['width'])\n",
    "    height = int(data['size']['height'])\n",
    "\n",
    "    classes = []\n",
    "    classes_text = []\n",
    "    truncated = []\n",
    "    poses = []\n",
    "    difficult_obj = []\n",
    "    \n",
    "    for obj in data['object']:\n",
    "        difficult = bool(int(obj['difficult']))\n",
    "        if ignore_difficult_instances and difficult:\n",
    "            continue\n",
    "        difficult_obj.append(int(difficult))\n",
    "        \n",
    "        class_name = obj['name']   \n",
    "        classes_text.append(class_name.encode('utf8'))\n",
    "        classes.append(label_map_dict[class_name])\n",
    "        truncated.append(int(obj['truncated']))\n",
    "        poses.append(obj['pose'].encode('utf8'))\n",
    " \n",
    "    feature_dict = {\n",
    "      'image/height': dataset_util.int64_feature(height),\n",
    "      'image/width': dataset_util.int64_feature(width),\n",
    "      'image/filename': dataset_util.bytes_feature(data['filename'].encode('utf8')),\n",
    "      'image/source_id': dataset_util.bytes_feature(data['filename'].encode('utf8')),\n",
    "      'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),\n",
    "      'image/encoded': dataset_util.bytes_feature(encoded_jpg),\n",
    "      'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),\n",
    "      'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),\n",
    "      'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),\n",
    "      'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),\n",
    "      'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),\n",
    "      'image/object/class/text': dataset_util.bytes_list_feature(classes_text),\n",
    "      'image/object/class/label': dataset_util.int64_list_feature(classes),\n",
    "      'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),\n",
    "      'image/object/truncated': dataset_util.int64_list_feature(truncated),\n",
    "      'image/object/view': dataset_util.bytes_list_feature(poses),\n",
    "    }\n",
    "\n",
    "    examples = tf.train.Example(features=tf.train.Features(feature=feature_dict))  \n",
    "    #writer.write(examples.SerializeToString())\n",
    "#writer.close() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 207,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C:/Users/l/Desktop/CSDN培训资料/第9轴_卷积神经网络分类任务和检测任务/W9_object_detection/object_detection/dataset_tools/data/out/a_train.tfrecords\n"
     ]
    }
   ],
   "source": [
    "a = os.path.join(FLAGS.new5_output_dir, 'a_train.tfrecords')\n",
    "print(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "writer = tf.python_io.TFRecordWriter(a)\n",
    "writer.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "writer = tf.python_io.TFRecordWriter(train_output_path)\n",
    "writer.write(examples.SerializeToString())\n",
    "writer.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 169,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dict_to_tf_example(data,\n",
    "                       label_map_dict,\n",
    "                       image_subdirectory,\n",
    "                       ignore_difficult_instances=False):\n",
    "\n",
    "    # 图片路径\n",
    "    img_path = os.path.join(image_subdirectory, data['filename'])  \n",
    "    with tf.gfile.GFile(img_path, 'rb') as fid:\n",
    "        encoded_jpg = fid.read()\n",
    "    encoded_jpg_io = io.BytesIO(encoded_jpg)\n",
    "    image = PIL.Image.open(encoded_jpg_io)\n",
    "    \n",
    "   # 读取图片信息\n",
    "    if image.format != 'JPEG':\n",
    "        raise ValueError('Image format not JPEG')\n",
    "    key = hashlib.sha256(encoded_jpg).hexdigest()\n",
    "\n",
    "    width = int(data['size']['width'])\n",
    "    height = int(data['size']['height'])\n",
    "\n",
    "    xmins = []\n",
    "    ymins = []\n",
    "    xmaxs = []\n",
    "    ymaxs = []\n",
    "    classes = []\n",
    "    classes_text = []\n",
    "    truncated = []\n",
    "    poses = []\n",
    "    difficult_obj = []\n",
    " \n",
    "    for obj in data['object']:\n",
    "        difficult = bool(int(obj['difficult']))\n",
    "        if ignore_difficult_instances and difficult:\n",
    "            continue\n",
    "        difficult_obj.append(int(difficult))\n",
    "        \n",
    "        xmin = float(np.min(nonzero_x_indices))\n",
    "        xmax = float(np.max(nonzero_x_indices))\n",
    "        ymin = float(np.min(nonzero_y_indices))\n",
    "        ymax = float(np.max(nonzero_y_indices))\n",
    "\n",
    "        xmins.append(xmin / width)\n",
    "        ymins.append(ymin / height)\n",
    "        xmaxs.append(xmax / width)\n",
    "        ymaxs.append(ymax / height)\n",
    "        class_name = get_class_name_from_filename(data['filename'])\n",
    "        classes_text.append(class_name.encode('utf8'))\n",
    "        classes.append(label_map_dict[class_name])\n",
    "        truncated.append(int(obj['truncated']))\n",
    "        poses.append(obj['pose'].encode('utf8'))\n",
    "\n",
    "    # 将解析出的信息写成TFRecord\n",
    "    feature_dict = {\n",
    "      'image/height': dataset_util.int64_feature(height),\n",
    "      'image/width': dataset_util.int64_feature(width),\n",
    "      'image/filename': dataset_util.bytes_feature(data['filename'].encode('utf8')),\n",
    "      'image/source_id': dataset_util.bytes_feature(data['filename'].encode('utf8')),\n",
    "      'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),\n",
    "      'image/encoded': dataset_util.bytes_feature(encoded_jpg),\n",
    "      'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),\n",
    "      'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),\n",
    "      'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),\n",
    "      'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),\n",
    "      'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),\n",
    "      'image/object/class/text': dataset_util.bytes_list_feature(classes_text),\n",
    "      'image/object/class/label': dataset_util.int64_list_feature(classes),\n",
    "      'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),\n",
    "      'image/object/truncated': dataset_util.int64_list_feature(truncated),\n",
    "      'image/object/view': dataset_util.bytes_list_feature(poses),\n",
    "    }\n",
    "\n",
    "    example = tf.train.Example(features=tf.train.Features(feature=feature_dict))\n",
    "    return example\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "data_path = os.path.join(root_path,'object_detection/dataset_tools/')\n",
    "\n",
    "# 定义变量 flags （name变量名称，default默认值，describe变量描述）\n",
    "flags = tf.app.flags\n",
    "flags.DEFINE_string('data_dir', os.path.join(data_path,'data'), 'Root directory to raw pet dataset.')\n",
    "flags.DEFINE_string('output_dir', os.path.join(data_path,'data/out'), 'Path to directory to output TFRecords.')\n",
    "flags.DEFINE_string('label_map_path', os.path.join('data/labels_items.txt'),'Path to label map proto')\n",
    "FLAGS = flags.FLAGS\n",
    "\n",
    "\n",
    "def dict_to_tf_example(data,\n",
    "                       label_map_dict,\n",
    "                       image_subdirectory,\n",
    "                       ignore_difficult_instances=False):\n",
    "    \"\"\"\n",
    "      Convert XML derived dict to tf.Example proto.\n",
    "\n",
    "      Notice that this function normalizes the bounding box coordinates provided\n",
    "      by the raw data.\n",
    "\n",
    "      Args:\n",
    "        data: dict holding PASCAL XML fields for a single image (obtained by\n",
    "          running dataset_util.recursive_parse_xml_to_dict)\n",
    "        mask_path: String path to PNG encoded mask.\n",
    "        label_map_dict: A map from string label names to integers ids.\n",
    "        image_subdirectory: String specifying subdirectory within the\n",
    "          Pascal dataset directory holding the actual image data.\n",
    "        ignore_difficult_instances: Whether to skip difficult instances in the\n",
    "          dataset  (default: False).\n",
    "        faces_only: If True, generates bounding boxes for pet faces.  Otherwise\n",
    "          generates bounding boxes (as well as segmentations for full pet bodies).\n",
    "\n",
    "      Returns:\n",
    "        example: The converted tf.Example.\n",
    "\n",
    "      Raises:\n",
    "        ValueError: if the image pointed to by data['filename'] is not a valid JPEG\n",
    "    \"\"\"\n",
    "\n",
    "    # 图片路径\n",
    "    img_path = os.path.join(image_subdirectory, data['filename']) # os.path.join():拼接路径\n",
    "    with tf.gfile.GFile(img_path, 'rb') as fid:\n",
    "        encoded_jpg = fid.read()\n",
    "    encoded_jpg_io = io.BytesIO(encoded_jpg)\n",
    "    image = PIL.Image.open(encoded_jpg_io)\n",
    "    \n",
    "   # 读取图片信息\n",
    "    if image.format != 'JPEG':\n",
    "        raise ValueError('Image format not JPEG')\n",
    "    key = hashlib.sha256(encoded_jpg).hexdigest()\n",
    "\n",
    "    width = int(data['size']['width'])\n",
    "    height = int(data['size']['height'])\n",
    "\n",
    "    xmins = []\n",
    "    ymins = []\n",
    "    xmaxs = []\n",
    "    ymaxs = []\n",
    "    classes = []\n",
    "    classes_text = []\n",
    "    truncated = []\n",
    "    poses = []\n",
    "    difficult_obj = []\n",
    " \n",
    "    for obj in data['object']:\n",
    "        difficult = bool(int(obj['difficult']))\n",
    "        if ignore_difficult_instances and difficult:\n",
    "            continue\n",
    "        difficult_obj.append(int(difficult))\n",
    "\n",
    "        xmins.append(xmin / width)\n",
    "        ymins.append(ymin / height)\n",
    "        xmaxs.append(xmax / width)\n",
    "        ymaxs.append(ymax / height)\n",
    "        class_name = get_class_name_from_filename(data['filename'])\n",
    "        classes_text.append(class_name.encode('utf8'))\n",
    "        classes.append(label_map_dict[class_name])\n",
    "        truncated.append(int(obj['truncated']))\n",
    "        poses.append(obj['pose'].encode('utf8'))\n",
    "\n",
    "    # 将解析出的信息写成TFRecord\n",
    "    feature_dict = {\n",
    "      'image/height': dataset_util.int64_feature(height),\n",
    "      'image/width': dataset_util.int64_feature(width),\n",
    "      'image/filename': dataset_util.bytes_feature(data['filename'].encode('utf8')),\n",
    "      'image/source_id': dataset_util.bytes_feature(data['filename'].encode('utf8')),\n",
    "      'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),\n",
    "      'image/encoded': dataset_util.bytes_feature(encoded_jpg),\n",
    "      'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),\n",
    "      'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),\n",
    "      'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),\n",
    "      'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),\n",
    "      'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),\n",
    "      'image/object/class/text': dataset_util.bytes_list_feature(classes_text),\n",
    "      'image/object/class/label': dataset_util.int64_list_feature(classes),\n",
    "      'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),\n",
    "      'image/object/truncated': dataset_util.int64_list_feature(truncated),\n",
    "      'image/object/view': dataset_util.bytes_list_feature(poses),\n",
    "    }\n",
    "\n",
    "    example = tf.train.Example(features=tf.train.Features(feature=feature_dict))\n",
    "    return example\n",
    "\n",
    "\n",
    "def create_tf_record(output_filename,\n",
    "                     label_map_dict,\n",
    "                     annotations_dir,\n",
    "                     image_dir,\n",
    "                     examples):\n",
    "    \"\"\"Creates a TFRecord file from examples.\n",
    "\n",
    "      Args:\n",
    "        output_filename: Path to where output file is saved.\n",
    "        label_map_dict: The label map dictionary.\n",
    "        annotations_dir: Directory where annotation files are stored.\n",
    "        image_dir: Directory where image files are stored.\n",
    "        examples: Examples to parse and save to tf record.\n",
    "        faces_only: If True, generates bounding boxes for pet faces.  Otherwise\n",
    "        generates bounding boxes (as well as segmentations for full pet bodies).\n",
    "    \"\"\"\n",
    "    \n",
    "    writer = tf.python_io.TFRecordWriter(output_filename)\n",
    "    for idx, example in enumerate(examples):\n",
    "        if idx % 100 == 0:\n",
    "            logging.info('On image %d of %d', idx, len(examples))\n",
    "        xml_path = os.path.join(annotations_dir, 'xmls', example + '.xml')\n",
    "\n",
    "        if not os.path.exists(xml_path):\n",
    "            logging.warning('Could not find %s, ignoring example.', xml_path)\n",
    "            continue\n",
    "            \n",
    "        with tf.gfile.GFile(xml_path, 'r') as fid:\n",
    "            xml_str = fid.read()\n",
    "        xml = etree.fromstring(xml_str)\n",
    "        data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']\n",
    "\n",
    "        try:\n",
    "            tf_example = dict_to_tf_example(data, label_map_dict, image_dir)\n",
    "            writer.write(tf_example.SerializeToString())\n",
    "        except ValueError:\n",
    "            logging.warning('Invalid example: %s, ignoring.', xml_path)\n",
    "\n",
    "    writer.close()   \n",
    "\n",
    "# TODO(derekjchow): Add test for pet/PASCAL main files.\n",
    "def main(_):\n",
    "    data_dir = FLAGS.data_dir\n",
    "    label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)\n",
    "\n",
    "    logging.info('Reading from Pet dataset.')\n",
    "    image_dir = os.path.join(data_dir, 'images')\n",
    "    annotations_dir = os.path.join(data_dir, 'annotations')\n",
    "    examples_path = os.path.join(annotations_dir, 'trainval.txt')\n",
    "    examples_list = dataset_util.read_examples_list(examples_path)\n",
    "\n",
    "    # Test images are not included in the downloaded data set, so we shall perform\n",
    "    # our own split.\n",
    "    random.seed(42)\n",
    "    random.shuffle(examples_list)\n",
    "    num_examples = len(examples_list)\n",
    "    num_train = int(0.7 * num_examples)\n",
    "    train_examples = examples_list[:num_train]\n",
    "    val_examples = examples_list[num_train:]\n",
    "    logging.info('%d training and %d validation examples.',\n",
    "               len(train_examples), len(val_examples))\n",
    "\n",
    "    train_output_path = os.path.join(FLAGS.output_dir, 'pet_train.record')\n",
    "    val_output_path = os.path.join(FLAGS.output_dir, 'pet_val.record')\n",
    "\n",
    "    create_tf_record(train_output_path, label_map_dict, annotations_dir,image_dir, train_examples)\n",
    "    create_tf_record(val_output_path, label_map_dict, annotations_dir, image_dir, val_examples)\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    tf.app.run()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
